{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nSpiking Neural Networks\n=======================\n\nSpiking neural networks differ little from the artificial neural networks\nin common use today. The main difference is that units in the network\ncommunicate only through spikes, each represented as a binary one or zero,\nand that the network evolves over time.\n\nHow to define a Network\n-----------------------\n\nThe spiking neural network primitives in myelin are designed to fit as\nseamlessly as possible into a traditional deep learning pipeline.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom myelin.torch.functional.lif import (\n    LIFParameters,\n)\n\nfrom myelin.torch.module.leaky_integrator import LICell\nfrom myelin.torch.module.lif import LIFFeedForwardCell\n\n\nclass Net(torch.nn.Module):\n    def __init__(\n            self,\n            device=\"cpu\",\n            num_channels=1,\n            feature_size=32,\n            model=\"super\",\n            dtype=torch.float,\n    ):\n        super(Net, self).__init__()\n        # spatial size left after two 5x5 convolutions and two 2x2 max-pools\n        self.features = int(((feature_size - 4) / 2 - 4) / 2)\n\n        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)\n        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)\n        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)\n        self.lif0 = LIFFeedForwardCell(\n            (32, feature_size - 4, feature_size - 4),\n            p=LIFParameters(method=model, alpha=100.0),\n        )\n        self.lif1 = LIFFeedForwardCell(\n            (64, int((feature_size - 4) / 2) - 4, int((feature_size - 4) / 2) - 4),\n            p=LIFParameters(method=model, alpha=100.0),\n        )\n        self.lif2 = LIFFeedForwardCell(\n            (1024,), p=LIFParameters(method=model, alpha=100.0)\n        )\n        self.out = LICell(1024, 10)\n        self.device = device\n        self.dtype = dtype\n\n    def forward(self, x):\n        # x has shape (time, batch, channels, height, width)\n        seq_length = x.shape[0]\n        batch_size = x.shape[1]\n\n        # specify the initial states\n        s0 = self.lif0.initial_state(batch_size, device=self.device, dtype=self.dtype)\n        s1 = self.lif1.initial_state(batch_size, device=self.device, dtype=self.dtype)\n        s2 = self.lif2.initial_state(batch_size, device=self.device, dtype=self.dtype)\n        so = self.out.initial_state(batch_size, device=self.device, dtype=self.dtype)\n\n        voltages = torch.zeros(\n            seq_length, batch_size, 10, device=self.device, dtype=self.dtype\n        )\n\n        # step the network once per timestep, carrying the neuron states along\n        for ts in range(seq_length):\n            z = self.conv1(x[ts, :])\n            z, s0 = self.lif0(z, s0)\n            z = torch.nn.functional.max_pool2d(z, 2, 2)\n            z = 10 * self.conv2(z)\n            z, s1 = self.lif1(z, s1)\n            z = torch.nn.functional.max_pool2d(z, 2, 2)\n            z = z.view(-1, self.features ** 2 * 64)\n            z = self.fc1(z)\n            z, s2 = self.lif2(z, s2)\n            v, so = self.out(torch.nn.functional.relu(z), so)\n            # store the output of the readout layer for this timestep\n            voltages[ts, :, :] = v\n        return voltages\n\n\nnet = Net()\nprint(net)"
      ]
    },
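    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The value of ``self.features`` in ``Net`` follows from the layer shapes: two 5x5\nconvolutions with stride 1 and no padding, each followed by a 2x2 max-pool.\nThe cell below is a small sketch that redoes this arithmetic in plain Python.\nThe helper names ``conv_out``, ``pool_out`` and ``flattened_features`` are\nintroduced here for illustration only and are not part of myelin or PyTorch.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Illustrative helpers (hypothetical names), assuming square inputs and dilation 1.\n",
        "def conv_out(size, kernel=5, stride=1, padding=0):\n",
        "    # spatial size after a square convolution\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "\n",
        "def pool_out(size, kernel=2, stride=2):\n",
        "    # spatial size after max pooling without padding\n",
        "    return (size - kernel) // stride + 1\n",
        "\n",
        "\n",
        "def flattened_features(feature_size=32, channels=64):\n",
        "    size = conv_out(feature_size)  # conv1\n",
        "    size = pool_out(size)          # first max-pool\n",
        "    size = conv_out(size)          # conv2\n",
        "    size = pool_out(size)          # second max-pool\n",
        "    # side length, and number of inputs to the first linear layer\n",
        "    return size, size * size * channels\n",
        "\n",
        "\n",
        "side, flat = flattened_features(32)\n",
        "print(side, flat)\n"
      ]
    },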
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can evaluate the network we just defined on an input of size 1x32x32.\nNote that, in contrast to typical spiking neural network simulators, time\nis just another dimension of the input tensor. Here we evaluate the network\non 16 timesteps, and there is an explicit batch dimension (the number of\ninputs evaluated concurrently with identical model parameters), so the input\nhas the layout (timesteps, batch, channels, height, width).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
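      },
      "outputs": [],
      "source": [
        "timesteps = 16\nbatch_size = 1\n# non-negative random input with layout (timesteps, batch, channels, height, width);\n# named `data` to avoid shadowing the built-in `input`\ndata = torch.abs(torch.randn(timesteps, batch_size, 1, 32, 32))\n# `out` has shape (timesteps, batch_size, 10)\nout = net(data)\nprint(out)"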
      ]
    },
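    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop over ``seq_length`` in ``Net.forward`` carries each neuron's state from\none timestep to the next. As a rough illustration of that pattern, the cell\nbelow steps a single leaky integrate-and-fire unit in plain Python. It is a\nsketch only and not myelin's implementation: the names ``lif_step`` and\n``run_lif`` and the values of ``decay`` and ``threshold`` are assumptions\nchosen for this example.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Hypothetical single-neuron sketch; constants are illustrative, not myelin defaults.\n",
        "def lif_step(current, voltage, decay=0.9, threshold=1.0):\n",
        "    # leak the stored voltage, then add the new input\n",
        "    voltage = decay * voltage + current\n",
        "    # emit a binary spike when the threshold is reached, then reset\n",
        "    spike = 1 if voltage >= threshold else 0\n",
        "    if spike:\n",
        "        voltage = 0.0\n",
        "    return spike, voltage\n",
        "\n",
        "\n",
        "def run_lif(currents):\n",
        "    voltage = 0.0  # initial state, playing the role of initial_state(...) in Net.forward\n",
        "    spikes = []\n",
        "    for current in currents:  # one iteration per timestep\n",
        "        spike, voltage = lif_step(current, voltage)\n",
        "        spikes.append(spike)\n",
        "    return spikes\n",
        "\n",
        "\n",
        "spike_train = run_lif([0.5] * 8)\n",
        "print(spike_train)\n"
      ]
    },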
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since the spiking neural network is implemented as a PyTorch module, we\ncan use the usual PyTorch primitives to optimize it. Note that ``out`` holds\none output per timestep, so the backward computation expects a gradient of\nthe same shape, with one entry for each timestep.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net.zero_grad()\nout.backward(torch.randn(timesteps, batch_size, 10))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>``myelin``, like PyTorch, only supports mini-batches. This means that,\n    contrary to most other spiking neural network simulators, ``myelin`` always\n    integrates several independent sets of spiking neural networks at once.</p></div>\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}